{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "torch.Size([4, 3, 6, 6])"
     },
     "metadata": {},
     "execution_count": 1
    }
   ],
   "source": [
    "import time\n",
    "import torch\n",
    "import torch.nn as nn \n",
    "import torch.nn.functional as F \n",
    "\n",
    "import d2lzh as d2l \n",
    "\n",
    "# Run on the GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "class Residual(nn.Module):\n",
    "    \"\"\"Two-convolution residual block with an optional 1x1 projection shortcut.\n",
    "\n",
    "    Main path: 3x3 conv -> BatchNorm -> ReLU -> 3x3 conv -> BatchNorm. The\n",
    "    result is added to the shortcut and passed through a final ReLU.\n",
    "\n",
    "    Args:\n",
    "        in_channels: Number of channels of the input tensor.\n",
    "        out_channels: Number of channels produced by both 3x3 convolutions.\n",
    "        use_1x1conv: If True, the shortcut is passed through a 1x1 convolution\n",
    "            (same stride as ``conv1``) mapping ``in_channels`` to\n",
    "            ``out_channels``. If False, the input itself is the shortcut.\n",
    "        stride: Stride of the first 3x3 convolution and of the 1x1 shortcut.\n",
    "    \"\"\"\n",
    "    def __init__(self,in_channels,out_channels,use_1x1conv=False,stride=1):\n",
    "        super(Residual,self).__init__()\n",
    "\n",
    "        self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1,stride=stride)\n",
    "        self.conv2 = nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=1)\n",
    "        if use_1x1conv:\n",
    "            self.conv3 = nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=stride)\n",
    "        else:\n",
    "            # No projection: forward() adds the unmodified input to the main path.\n",
    "            self.conv3 = None\n",
    "\n",
    "        self.bn1 = nn.BatchNorm2d(out_channels)\n",
    "        self.bn2 = nn.BatchNorm2d(out_channels)\n",
    "\n",
    "    def forward(self,x):\n",
    "        \"\"\"Return ``relu(main_path(x) + shortcut(x))``.\"\"\"\n",
    "        y = F.relu(self.bn1(self.conv1(x)))\n",
    "        y = self.bn2(self.conv2(y))\n",
    "\n",
    "        # Compare with None explicitly. Truthiness of a module falls back to\n",
    "        # __len__ when the module defines it, so `if self.conv3:` could skip\n",
    "        # a configured shortcut layer.\n",
    "        if self.conv3 is not None:\n",
    "            x = self.conv3(x)\n",
    "\n",
    "        return F.relu(y + x)\n",
    "\n",
    "# Smoke test: with equal input/output channel counts and the default stride of 1,\n",
    "# the block's output has the same shape as its input.\n",
    "blk = Residual(3,3)\n",
    "x = torch.rand((4,3,6,6))\n",
    "blk(x).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "0  output shape:\t torch.Size([1, 64, 112, 112])\n1  output shape:\t torch.Size([1, 64, 112, 112])\n2  output shape:\t torch.Size([1, 64, 112, 112])\n3  output shape:\t torch.Size([1, 64, 56, 56])\nresidual_block1  output shape:\t torch.Size([1, 64, 56, 56])\nresidual_block2  output shape:\t torch.Size([1, 128, 28, 28])\nresidual_block3  output shape:\t torch.Size([1, 256, 14, 14])\nresidual_block4  output shape:\t torch.Size([1, 512, 7, 7])\nglobal_avg_pool  output shape:\t torch.Size([1, 512, 1, 1])\nfc  output shape:\t torch.Size([1, 10])\n"
    }
   ],
   "source": [
    "# Stem of the network: 7x7 stride-2 convolution on a 1-channel input, batch\n",
    "# normalization, ReLU, then a 3x3 stride-2 max pooling. The residual stages and\n",
    "# the classifier head are appended to this container further down in this cell.\n",
    "net = nn.Sequential(\n",
    "    nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),\n",
    "    nn.BatchNorm2d(64),\n",
    "    nn.ReLU(),\n",
    "    nn.MaxPool2d(kernel_size=3,stride=2,padding=1) \n",
    ")\n",
    "\n",
    "def resnet_block(in_channels,out_channels,num_residuals,first_block=False):\n",
    "    \"\"\"Build one ResNet stage made of ``num_residuals`` Residual units.\n",
    "\n",
    "    Unless ``first_block`` is set, the first unit maps ``in_channels`` to\n",
    "    ``out_channels`` with a stride-2 Residual that uses the 1x1 shortcut\n",
    "    convolution. Every other unit is ``Residual(out_channels, out_channels)``\n",
    "    with the default stride. With ``first_block=True`` all units are of the\n",
    "    latter kind, and ``in_channels`` must equal ``out_channels`` (checked\n",
    "    with ``assert``).\n",
    "\n",
    "    Returns:\n",
    "        An ``nn.Sequential`` holding the units in order.\n",
    "    \"\"\"\n",
    "    if first_block:\n",
    "        assert in_channels == out_channels\n",
    "\n",
    "    def make_unit(position):\n",
    "        # Only the leading unit of a non-first stage changes channels/resolution.\n",
    "        downsample = position == 0 and not first_block\n",
    "        if downsample:\n",
    "            return Residual(in_channels,out_channels,use_1x1conv=True,stride=2)\n",
    "        return Residual(out_channels,out_channels)\n",
    "\n",
    "    return nn.Sequential(*(make_unit(position) for position in range(num_residuals)))\n",
    "\n",
    "# Four residual stages of two units each. The first keeps 64 channels; each\n",
    "# later stage doubles the channel count and starts with a stride-2 unit.\n",
    "net.add_module(\"residual_block1\",resnet_block(64,64,2,first_block=True))\n",
    "net.add_module(\"residual_block2\",resnet_block(64,128,2))\n",
    "net.add_module(\"residual_block3\",resnet_block(128,256,2))\n",
    "net.add_module(\"residual_block4\",resnet_block(256,512,2))\n",
    "\n",
    "# Classifier head: d2l pooling/flatten helpers followed by a 512 -> 10 linear layer.\n",
    "net.add_module(\"global_avg_pool\",d2l.GlobalAvgPool2d())\n",
    "net.add_module(\"fc\",nn.Sequential(d2l.FlattenLayer(),nn.Linear(512,10)))\n",
    "\n",
    "# Shape trace on a dummy input. Run it in eval mode and without autograd:\n",
    "# in training mode every BatchNorm2d layer would fold this random batch into\n",
    "# its running statistics before any real data has been seen.\n",
    "x = torch.rand(1,1,224,224)\n",
    "was_training = net.training\n",
    "net.eval()\n",
    "with torch.no_grad():\n",
    "    for name,layer in net.named_children():\n",
    "        x = layer(x)\n",
    "        print(name,\" output shape:\\t\",x.shape)\n",
    "# Restore whatever mode the network was in before the trace.\n",
    "net.train(was_training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "training on  cuda\nepoch 1, loss 0.4373, train acc 0.839,test acc 0.869,time 29.4\nepoch 2, loss 0.1500, train acc 0.889,test acc 0.848,time 23.7\nepoch 3, loss 0.0856, train acc 0.905,test acc 0.894,time 23.8\nepoch 4, loss 0.0580, train acc 0.913,test acc 0.848,time 23.9\nepoch 5, loss 0.0417, train acc 0.922,test acc 0.878,time 23.9\n"
    }
   ],
   "source": [
    "# Hyperparameters for this run.\n",
    "batch_size = 256\n",
    "lr = 0.001\n",
    "num_epochs = 5\n",
    "\n",
    "# d2l helper (defined outside this notebook); presumably yields the\n",
    "# Fashion-MNIST train/test batch iterators -- verify against d2lzh.\n",
    "train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size)\n",
    "\n",
    "# Adam over all parameters of the network assembled above.\n",
    "optimizer = torch.optim.Adam(net.parameters(),lr=lr)\n",
    "d2l.train_ch5(net,train_iter,test_iter,batch_size,optimizer,device,num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python37764bitpytorchnotebookconda6e7a8693df0d4d92aca09d521275d23a",
   "display_name": "Python 3.7.7 64-bit ('pytorch_notebook': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}